from unittest.mock import patch
import numpy as np
from scipy.integrate import trapezoid
from qutip import (destroy, propagator, Propagator, propagator_steadystate,
                   steadystate, tensor, qeye, basis, QobjEvo, sesolve,
                   liouvillian, rand_dm, enr_identity, enr_destroy)
import qutip
import pytest
from qutip.solver.brmesolve import BRSolver
from qutip.solver.mesolve import MESolver, mesolve
from qutip.solver.sesolve import SESolver
from qutip.solver.mcsolve import MCSolver


def testPropHOB():
    """Propagator: a time-independent H matches the matrix exponential at t=1."""
    lowering = destroy(5)
    H = lowering.dag() * lowering
    expected = (-1j * H).expm()
    computed = propagator(H, 1)
    assert (computed - expected).norm('max') < 1e-4


def testPropObj():
    """Propagator object: memoization, reuse within ``tol`` and (t, t0) calls."""
    a = destroy(5)
    H = a.dag()*a
    U = Propagator(
        H, c_ops=[a], options={"method": "dop853"}, memoize=5, tol=1e-5
    )
    # A few calls to fill the stored propagators. Six times are requested;
    # with memoize=5 the test expects exactly five to be retained.
    for t in (0.5, 0.25, 0.75, 1, -1, -.5):
        U(t)
    assert len(U.times) == 5

    tolerance = 1e-4
    assert (U(1) - propagator(H, 1, [a])).norm('max') < tolerance
    assert (U(0.5) - propagator(H, 0.5, [a])).norm('max') < tolerance
    # U(t, t0) evolves from t0 to t, i.e. over an interval of length 1 here.
    assert (U(1.5, 0.5) - propagator(H, 1, [a])).norm('max') < tolerance
    # 0.5 + 1e-6 lies within tol=1e-5 of the stored time 0.5, so the
    # precomputed propagator should be returned unchanged.
    assert (U(0.5) - U(0.5 + 1e-6)).norm('max') < 1e-10


def func(t):
    """Time-dependent coefficient cos(t); accepts scalars or numpy arrays."""
    value = np.cos(t)
    return value


def testPropHOTdFunc():
    """Propagator: func td format.

    The name must be unique within the module: a later ``def`` with the same
    name would rebind it and pytest would silently never collect this test.
    """
    a = destroy(5)
    H = a.dag()*a
    Htd = [H, [H, func]]
    U = propagator(Htd, 1)
    # H(t) = (1 + cos t) H commutes with itself at all times, so the exact
    # propagator is exp(-i H * integral_0^1 (1 + cos t) dt); the integral is
    # approximated here with the trapezoid rule.
    ts = np.linspace(0, 1, 101)
    U2 = (-1j * H * trapezoid(1 + func(ts), ts)).expm()
    assert (U - U2).norm('max') < 1e-4


def testPropHOTdArray():
    """Propagator: func array td format + open.

    The name must be unique within the module; reusing an earlier test's name
    rebinds it and pytest silently drops the earlier test.
    """
    a = destroy(5)
    H = a.dag()*a
    # Coefficient samples span slightly more than [0, 1], presumably so the
    # array coefficient is defined over the whole integration window.
    ts = np.linspace(-0.01, 1.01, 103)
    coeffs = np.cos(ts)
    Htd = [H, [H, coeffs]]
    rho_0 = rand_dm(5)
    # Apply the open-system propagator to rho_0 and compare with mesolve.
    rho_1_prop = propagator(Htd, 1, c_ops=[a], tlist=ts)(rho_0)
    rho_1_me = mesolve(QobjEvo(Htd, tlist=ts), rho_0, [0, 1], [a]).final_state

    assert (rho_1_prop - rho_1_me).norm('max') < 1e-4


def testPropObjTd():
    """Propagator object with a string coefficient and per-call ``w`` args."""
    a = destroy(5)
    H = a.dag()*a

    def td_hamiltonian():
        # A fresh list each call, mirroring separate literals per use.
        return [H, [H, "w*t"]]

    def reference(t, w):
        return propagator(td_hamiltonian(), t, [a], args={'w': w})

    U = Propagator(td_hamiltonian(), c_ops=[a], args={'w': 1})
    assert (U(1) - reference(1, 1)).norm('max') < 1e-4
    assert (U(0.5, w=2) - reference(0.5, 2)).norm('max') < 1e-4
    # U(t, t0, ...) is compared with the t=1.5 entry of a list-time run
    # starting at 0.5.
    assert (
        U(1.5, 0.5, w=1.5) - reference([0.5, 1.5], 1.5)[1]
    ).norm('max') < 1e-4


def testPropHOSteady():
    """Propagator: steady state.

    The state from ``propagator_steadystate`` applied to the propagator over
    t = 2*pi must agree with ``steadystate`` for the same H and c_ops.
    """
    a = destroy(5)
    H = a.dag()*a
    kappa, n_th = 0.1, 2
    # Loss through a at rate kappa*(1+n_th), gain through a.dag() at
    # rate kappa*n_th.
    c_op_list = [
        np.sqrt(kappa * (1 + n_th)) * a,
        np.sqrt(kappa * n_th) * a.dag(),
    ]
    rho_prop = propagator_steadystate(propagator(H, 2*np.pi, c_op_list))
    rho_ss = steadystate(H, c_op_list)
    assert (rho_prop - rho_ss).norm('max') < 1e-4


@pytest.mark.parametrize("H", [
    pytest.param(tensor([qeye(2), qeye(2)]), id="tensor"),
    pytest.param(enr_identity([2, 2], 1), id="enr"),
])
def testPropHDims(H):
    """Propagator: the returned unitary keeps the dims of H.

    The parametrized ``H`` is used as-is. Rebinding it inside the test would
    make the "enr" case silently re-test the tensor Hamiltonian.
    """
    U = propagator(H, 1)
    assert U._dims == H._dims


@pytest.mark.parametrize("L", [
    pytest.param(
        liouvillian(qeye(2) & qeye(2), [destroy(2) & destroy(2)]),
        id="tensor",
    ),
    pytest.param(
        liouvillian(enr_identity([2, 2], 1), list(enr_destroy([2, 2], 1))),
        id="enr",
    ),
])
def testPropHSuper(L):
    """Propagator: a Liouvillian input yields a result with the same dims."""
    super_propagator = propagator(L, 1)
    assert super_propagator._dims == L._dims


def testPropEvo():
    """A Propagator wrapped in QobjEvo evolves a ket the same way sesolve does."""
    a = destroy(5)
    H = a.dag()*a
    U = Propagator([H, [a + a.dag(), "w*t"]], args={'w': 1})
    evolved = QobjEvo(U) @ basis(5, 4)
    times = np.linspace(0, 1, 6)
    reference_states = sesolve(
        [H, [a + a.dag(), "w*t"]], basis(5, 4), tlist=times, args={'w': 1}
    ).states
    for t, ref_state in zip(times, reference_states):
        # Compare via |<psi(t)|ref>| so a global phase does not matter.
        overlap_magnitude = abs(evolved(t).overlap(ref_state))
        assert overlap_magnitude > 1-1e-6


def _make_se(H, a):
    """Return a SESolver for H; ``a`` is accepted only for a uniform signature."""
    solver = SESolver(H)
    return solver


def _make_me(H, a):
    """Return an MESolver for H with ``a`` as its single collapse operator."""
    collapse_ops = [a]
    return MESolver(H, collapse_ops)


def _make_br(H, a):
    """Return a BRSolver for H coupled through ``a + a.dag()``.

    The spectrum coefficient evaluates ``w >= 0``, a step in ``w``, with
    ``w`` defaulting to 0 via ``args``.
    """
    def _spectrum(t, w):
        return w >= 0

    spectra = qutip.coefficient(_spectrum, args={"w": 0})
    coupling = a + a.dag()
    return BRSolver(H, [(coupling, spectra)])


@pytest.mark.parametrize('solver', [
    pytest.param(_make_se, id='SESolver'),
    pytest.param(_make_me, id='MESolver'),
    pytest.param(_make_br, id='BRSolver'),
])
def testPropSolver(solver):
    """A Propagator built from a solver instance matches ``propagator``."""
    a = destroy(5)
    H = a.dag()*a
    U = Propagator(solver(H, a))
    # Only the SESolver factory has no dissipation; the others use ``a``.
    c_ops = [] if solver is _make_se else [a]

    def reference(t):
        return propagator(H, t, c_ops)

    assert (U(1) - reference(1)).norm('max') < 1e-4
    assert (U(0.5) - reference(0.5)).norm('max') < 1e-4
    assert (U(1.5, 0.5) - reference(1)).norm('max') < 1e-4


def testPropMCSolver():
    """Propagator refuses an MCSolver with a 'Non-deterministic' TypeError."""
    a = destroy(5)
    mc_solver = MCSolver(a.dag()*a, [a])
    # match uses re.search; the ^ anchor makes it a prefix check.
    with pytest.raises(TypeError, match=r"^Non-deterministic"):
        Propagator(mc_solver)


def testPropPiecewiseConst():
    """Piecewise propagation across a single switch from H0 to H1 at t=0.5."""
    H0 = qutip.sigmaz()
    H1 = qutip.sigmax()
    switch_time = 0.5

    def H_func(t, args):
        if t < switch_time:
            return H0
        return H1

    U = propagator(H_func, 2, piecewise_t=[switch_time])
    # The later segment acts on the left: exp(-i H1 * 1.5) exp(-i H0 * 0.5).
    early = (-1j * H0 * 0.5).expm()
    late = (-1j * H1 * 1.5).expm()
    assert (U - late * early).norm('max') < 1e-4


def testPropPiecewiseConstantH():
    """Breakpoints on a constant H leave the propagator exp(-2i H0) unchanged."""
    H0 = qutip.sigmaz()

    def H_func(t, args):
        return H0

    breakpoints = [0.5, 1.0]
    U = propagator(H_func, 2, piecewise_t=breakpoints)
    assert (U - (-1j * H0 * 2).expm()).norm('max') < 1e-4


def testPropPiecewiseSingleCop():
    """Piecewise and regular propagation agree with one bare collapse operator."""
    H0, H1 = qutip.sigmaz(), qutip.sigmax()
    a = destroy(2)

    def H_func(t, args):
        if t < 1.0:
            return H0
        return H1

    U_piecewise = propagator(H_func, 2, piecewise_t=[1.0], c_ops=a)
    U_regular = propagator(H_func, 2, c_ops=a)
    assert (U_piecewise - U_regular).norm('max') < 1e-4


def testPropPiecewiseListCops():
    """Piecewise and regular propagation agree with a list of collapse ops."""
    H0, H1 = qutip.sigmaz(), qutip.sigmax()
    a = destroy(2)

    def H_func(t, args):
        if t < 0.75:
            return H0
        return H1

    kappa = 0.1
    decay = np.sqrt(kappa) * a
    pump = np.sqrt(kappa * 0.5) * a.dag()
    c_ops = [decay, pump]

    U_piecewise = propagator(H_func, 2, piecewise_t=[0.75], c_ops=c_ops)
    U_regular = propagator(H_func, 2, c_ops=c_ops)
    assert (U_piecewise - U_regular).norm('max') < 1e-4


def testPropPiecewiseSuperoperator():
    """Piecewise propagation of a Liouvillian switching from L0 to L1 at t=1."""
    a = destroy(2)
    L0 = liouvillian(qutip.sigmaz(), [a])
    L1 = liouvillian(qutip.sigmax(), [a])

    def L_func(t, args):
        if t < 1.0:
            return L0
        return L1

    U_piecewise = propagator(L_func, 2, piecewise_t=[1.0])
    U_regular = propagator(L_func, 2)
    # Each segment lasts 1.0. For a Liouvillian the exponent is L*dt (no
    # -1j factor), and the later segment acts on the left.
    expected = (L1 * 1.0).expm() * (L0 * 1.0).expm()
    assert (U_piecewise - expected).norm('max') < 1e-4
    assert (U_piecewise - U_regular).norm('max') < 1e-4


def testPropPiecewiseMultipleTimes():
    """Three constant segments: H0 on [0, 0.5), H1 on [0.5, 1.5), then H2."""
    H0 = qutip.sigmaz()
    H1 = qutip.sigmax()
    H2 = qutip.sigmay()

    def H_func(t, args):
        # Return the Hamiltonian of the first segment whose end lies after t.
        for segment_end, H in ((0.5, H0), (1.5, H1)):
            if t < segment_end:
                return H
        return H2

    U = propagator(H_func, 2.5, piecewise_t=[0.5, 1.5])
    step_2 = (-1j * H2 * 1.0).expm()
    step_1 = (-1j * H1 * 1.0).expm()
    step_0 = (-1j * H0 * 0.5).expm()
    # Later segments act on the left.
    expected = step_2 * step_1 * step_0
    assert (U - expected).norm('max') < 1e-4


def testPropPiecewiseListOutput():
    """Piecewise propagation with a list of times matches the regular path.

    Both calls must return one propagator per entry of ``tlist``, and the two
    results must agree entry by entry.
    """
    H0 = qutip.sigmaz()
    H1 = qutip.sigmax()
    H2 = qutip.sigmay()
    H3 = qutip.sigmaz() + qutip.sigmax()

    def H_func(t, args):
        if t < 0.5:
            return H0
        elif t < 1.0:
            return H1
        elif t < 1.5:
            return H2
        else:
            return H3

    tlist = [0, 0.5, 1.0, 1.5, 2.0]
    U_list = propagator(H_func, tlist, piecewise_t=[0.5, 1.0, 1.5])
    U2_list = propagator(H_func, tlist)

    # zip() stops at the shorter input, so check lengths first. Otherwise an
    # empty or truncated result would make the loop below pass vacuously.
    assert len(U_list) == len(tlist)
    assert len(U2_list) == len(tlist)

    for U, U2 in zip(U_list, U2_list):
        assert (U - U2).norm('max') < 1e-4


def testPropPiecewiseBoundaryConsistency():
    """Which side owns t == 1.0 must not change the piecewise result."""
    H0 = qutip.sigmaz()
    H1 = qutip.sigmax()

    # The boundary point t == 1.0 belongs to H0.
    def H_func_leq(t, args):
        if t <= 1.0:
            return H0
        return H1

    # The boundary point t == 1.0 belongs to H1.
    def H_func_lt(t, args):
        if t < 1.0:
            return H0
        return H1

    U_closed = propagator(H_func_leq, 2.0, piecewise_t=[1.0])
    U_open = propagator(H_func_lt, 2.0, piecewise_t=[1.0])
    assert (U_closed - U_open).norm('max') < 1e-10


def testPropPiecewiseTimeDependentCops():
    """Piecewise propagation also works with a function-valued collapse op."""
    H0 = qutip.sigmaz()
    H1 = qutip.sigmax()
    a = destroy(2)

    def H_func(t, args):
        if t < 1.0:
            return H0
        return H1

    # The collapse rate jumps from 0.1 to 0.2 at the same breakpoint as H.
    def c_func(t, args):
        if t < 1.0:
            rate = 0.1
        else:
            rate = 0.2
        return np.sqrt(rate) * a

    U_piecewise = propagator(H_func, 2.0, c_ops=c_func, piecewise_t=[1.0])
    U_regular = propagator(H_func, 2.0, c_ops=c_func)
    assert (U_piecewise - U_regular).norm('max') < 1e-4


@pytest.mark.parametrize("n_points", [4, 10, 11])
def testPropPiecewiseUniformGrid(n_points):
    """Piecewise optimization on a uniform time grid matches the regular path.

    Both results must contain one propagator per grid point before they are
    compared entry by entry.
    """
    H0 = qutip.sigmaz()
    H1 = qutip.sigmax()

    def H_func(t, args):
        return H0 if t < 1.0 else H1

    tlist = np.linspace(0, 2.0, n_points)

    U_list_pw = propagator(H_func, tlist, piecewise_t=[1.0])
    U_list_reg = propagator(H_func, tlist)

    # Check both lengths: zip() would silently truncate to the shorter list.
    assert len(U_list_pw) == n_points
    assert len(U_list_reg) == n_points

    for U_pw, U_reg in zip(U_list_pw, U_list_reg):
        assert (U_pw - U_reg).norm('max') < 1e-4


@pytest.mark.parametrize("n_points", [11, 12])
def testPropPiecewiseCachingOptimization(n_points):
    """Verify that exponential caching bounds the number of expm calls.

    ``Qobj.expm`` is wrapped with a counter only while the piecewise
    propagator runs. The piecewise results are then checked against the
    regular (unpatched) propagator.
    """
    H0 = qutip.sigmaz()

    def H_func(t, args):
        return H0

    original_expm = qutip.Qobj.expm
    call_count = 0

    def counted_expm(self, *args, **kwargs):
        # Forward all arguments so the wrapper stays transparent if the
        # propagator passes any options to expm; a fixed (self) signature
        # would raise an unrelated TypeError.
        nonlocal call_count
        call_count += 1
        return original_expm(self, *args, **kwargs)

    tlist = np.linspace(0, 2.0, n_points)
    with patch.object(qutip.Qobj, 'expm', counted_expm):
        U_list = propagator(H_func, tlist, piecewise_t=[0.5, 1.0, 1.5])
    # NOTE(review): 10 is the upper bound this test has always used for the
    # cached path; its derivation is not documented here — confirm against
    # the piecewise implementation.
    assert call_count <= 10

    U_list_reg = propagator(H_func, tlist)

    assert len(U_list) == len(U_list_reg)
    for U_pw, U_reg in zip(U_list, U_list_reg):
        assert (U_pw - U_reg).norm('max') < 1e-4
